library(R2jags)
library(xlsx)
library(openxlsx)
library(spatstat)
library(dplyr)
library(readr)
require(R2wd)
library(LaplacesDemon) # for Geweke.Diagnostic()
library(parallel)
#library(doParallel)
library(splitstackshape) # For stratified sampling
library(caret) # function createFolds

# Let print() show very large objects in full.
options(max.print = 999999999)

# memory.limit()/memory.size() only ever had an effect on Windows, and since
# R 4.2.0 they are stubs that just warn; call them only where they still work.
if (.Platform$OS.type == "windows" && getRversion() < "4.2.0") {
  memory.limit(10000)
  memory.size(max = TRUE)
}

# Collect posterior means and standard deviations of selected parameters.
#
# thedata:  summary matrix whose row names are node names such as "mu[3]" and
#           which has "mean" and "sd" columns (e.g. R2jags BUGSoutput$summary).
# varnames: parameter base names (character vector or list of strings). Each is
#           matched as a whole word, so "mu" matches "mu[1]" but not "sigma_mu".
# Returns a named list: one "m.<name>" element (the matching "mean" values) per
#   name, followed by one "sd.<name>" element (the matching "sd" values) per
#   name, each in the row order of thedata. Empty varnames gives an empty list.
priorslist <- function(thedata, varnames) {
  varnames <- as.character(unlist(varnames))
  # paste0() would recycle a zero-length vector to "", so handle it up front.
  if (length(varnames) == 0) {
    return(setNames(list(), character(0)))
  }
  node_names <- dimnames(thedata)[[1]]
  matches <- lapply(paste0("\\b", varnames, "\\b"), grepl, x = node_names)
  means <- lapply(matches, function(rows) thedata[rows, "mean"])
  sds <- lapply(matches, function(rows) thedata[rows, "sd"])
  setNames(c(means, sds), c(paste0("m.", varnames), paste0("sd.", varnames)))
}

# Run relative to this script's folder so the relative input/output paths
# below resolve. rstudioapi can only report the script path inside RStudio
# for a saved document; otherwise the current working directory is kept
# (e.g. when the script is launched with Rscript from its own folder).
if (requireNamespace('rstudioapi', quietly = TRUE) && rstudioapi::isAvailable()) {
  script_path <- rstudioapi::getActiveDocumentContext()$path
  if (nzchar(script_path)) {
    setwd(dirname(script_path))
  }
}
getwd() # Check directory

outdir <- 'irt_estimates/'
dir.create(file.path(getwd(), outdir), showWarnings = FALSE)
inpdir <- 'raw_indicators/'

# Negation of %in%: TRUE for elements of x that do not occur in table.
'%!in%' <- Negate('%in%')

## JAGS code here: the IRT model fitted to the training data.
## data{} block: converts the percentage shares into counts,
##   opinionr[i,2] = round(support[i] * samplesize[i] / 100)   (column 2 = support)
##   opinionr[i,1] = round(neutral[i] * samplesize[i] / 100)   (column 1 = neutral)
## model{} block, for observation i and column j:
##   opinionr[i,j] ~ Binomial(samplesize[i], p[i,j])
##   p[i,j] ~ Beta(alpha[i,j], beta) with alpha = beta*m/(1-m), so E[p] = m
##     (a beta-binomial; larger beta means less dispersion around m)
##   m[i,j] = Phi((mu[t[i]] - diff[qn[i],j]) / sqrt(disc[qn[i]]^2 + sigma[t[i]]^2))
##   diff[q,] is diff0[q,] sorted ascending, so diff[q,1] <= diff[q,2] and hence
##   m[i,1] >= m[i,2]. NOTE(review): this suggests `neutral` is a cumulative
##   share (neutral-or-support) rather than neutral alone -- confirm with the data.
## Time dynamics: mu[begin] ~ N(0, 1); for t > begin,
##   mu[t] ~ N(mu[t-1] + sigma_mu * mu_raw[t-1], precision 100, i.e. sd 0.1)
##   with mu_raw ~ N(0, 1). JAGS dnorm() takes a precision, not a variance.
## Priors: beta ~ U(0,100), sigma_mu ~ U(0,10), disc[q] ~ U(0,6),
##   diff0[q,j] ~ U(-4,4), sigma[t] ~ U(0,10).
## Data expected: len, support, neutral, samplesize, t, qn, nquest, begin, end,
##   with t[i] in begin..end and qn[i] in 1..nquest.
main_model="data{for (i in 1:len){opinionr[i,2]<-round(support[i]*samplesize[i]/100) 
		opinionr[i,1]<-round(neutral[i]*samplesize[i]/100)}}
model{for (i in 1:len){for(j in 1:2){opinionr[i,j]~dbin(p[i,j],samplesize[i])      
		p[i,j]~dbeta(alpha[i,j], beta)              
		alpha[i,j]<--beta*m[i,j]/(m[i,j]-1)         
		m[i,j]<-phi(x[i,j])             
	  x[i,j]<-(mu[t[i]]-diff[qn[i],j])/sqrt((disc[qn[i]])^2+(sigma[t[i]])^2)}}
beta~dunif(0,100)
mu[begin]~dnorm(0,1)
sigma_mu~dunif(0,10)
for (i in 1:nquest){disc[i]~dunif(0,6) 
	for (j in 1:2){diff0[i,j]~dunif(-4,4)}
	diff[i,1:2]<-sort(diff0[i,])}
for (i in begin:end){sigma[i]~dunif(0,10)}
for (i in begin:(end-1)){mu_raw[i]~dnorm(0,1)}
for (i in (begin+1):end){mu[i]~dnorm(mu[i-1]+sigma_mu*mu_raw[i-1],100)}}"
## Save the model to a temp file for jags.parallel(). Passing a path lets
## writeLines() open and close the file itself, so no connection is left
## open if writing fails.
writeLines(main_model, "main_model.temp")

## JAGS code for the replication run on held-out (test) data. It has the same
## beta-binomial / probit structure as main_model, but every parameter is fixed
## to values supplied as data:
##   beta <- m.beta; disc[q] <- m.disc[q];
##   diff0[q,j] <- m.diff[q + nquest*(j-1)]  (m.diff holds column 1 for all
##     questions, then column 2), re-sorted into diff[q,];
##   sigma[t] <- m.sigma[t-(begin-1)], mu[t] <- m.mu[t-(begin-1)]  (so element 1
##     of m.mu/m.sigma must correspond to this data's `begin`).
## data{} block: observed counts (opinionr) and percentages (opinionp);
##   column 1 = neutral, column 2 = support.
## model{} block: draws replicated counts opinionr.sim, converts them to
##   percentages opinionp.sim, computes expected percentages opinionp.exp =
##   100*m and squared residuals sqresid against the observed percentages.
##   diff.support[q] = diff[q,2], the larger of question q's two thresholds.
## NOTE(review): opinionr is not referenced inside model{}, and mu, sigma and
##   diff.support are deterministic given the data.
model_fit="data{for (i in 1:len){opinionr[i,2]<-round(support[i]*samplesize[i]/100)
	opinionp[i,2]<-support[i] 
	opinionr[i,1]<-round(neutral[i]*samplesize[i]/100)
	opinionp[i,1]<-neutral[i]}}
model{for (i in 1:len){for(j in 1:2){sqresid[i,j]<-(opinionp[i,j]-opinionp.sim[i,j])^2
	opinionp.exp[i,j]<-m[i,j]*100
	opinionp.sim[i,j]<-opinionr.sim[i,j]/samplesize[i]*100
	opinionr.sim[i,j]~dbin(p[i,j],samplesize[i]) 
	p[i,j]~dbeta(alpha[i,j], beta)              
	alpha[i,j]<--beta*m[i,j]/(m[i,j]-1)         
	m[i,j]<-phi(x[i,j])                   
	x[i,j]<-(mu[t[i]]-diff[qn[i],j])/sqrt((disc[qn[i]])^2+(sigma[t[i]])^2)}}
beta<-m.beta 
for (i in 1:nquest){disc[i]<-m.disc[i]
	for (j in 1:2){diff0[i,j]<-m.diff[i+nquest*(j-1)]}
	diff[i,1:2]<-sort(diff0[i,])
	diff.support[i]<-diff[i,2]}
for (i in begin:end){sigma[i]<-m.sigma[i-(begin-1)]
	mu[i]<-m.mu[i-(begin-1)]}}"
## Save the replication model to a temp file for jags.parallel(). Passing a
## path lets writeLines() open and close the file itself, so no connection is
## left open if writing fails.
writeLines(model_fit, "model_fit.temp")

##################################################################################################

# FIT OF IRT MODEL
#-----------------

# Import data: percentage shares (`support`, `neutral`) with their
# `samplesize`, per time label `t` and `question`.
data <- read.table(paste0(inpdir, 'EU_irt.txt'), sep = ';', header = TRUE)

# Time labels may carry a semester suffix (" S1"/" S2"); keep 1973 onwards.
data <- subset(data, as.numeric(gsub(' S1| S2', '', data$t)) >= 1973)

# The JAGS models index parameters by integer time point. as.numeric() on a
# label such as "1973 S1" gives NA (read.table() returns character columns by
# default since R 4.0), so non-numeric labels go through factor(): assuming
# labels of the form "YYYY" / "YYYY Sn", the sorted levels are chronological
# and yield consecutive indices 1, 2, ...
data$t <- if (is.numeric(data$t)) {
  as.numeric(data$t)
} else {
  as.numeric(factor(as.character(data$t)))
}

# Recount how often each question was asked in the retained period (replacing
# any `asked` column from the file) and keep questions asked more than twice.
data$asked <- NULL
data <- add_count(data, question, name = 'asked')
data <- subset(data, asked > 2)
data <- subset(data, select = c('t', 'question', 'support', 'neutral', 'samplesize', 'asked'))
rownames(data) <- seq_len(nrow(data)) # Reset row numbers after subsetting

cv.data <- data

# Cross-validation ----
# Outer loop: one pass per random seed `f` (only seed 1 here). Inner loop:
# `nfolds`-fold split of the observations, stratified by question. In each fold
# the IRT model (main_model.temp) is fitted to the training rows; its posterior
# means are then fixed as data in model_fit.temp for the held-out rows, and the
# held-out support percentages are predicted from them.
# Three statistics are kept per fold (columns of cv_results):
#   1. RMSE of a baseline predicting each question's mean training support,
#   2. RMSE of the IRT predictions,
#   3. 1 - SS(IRT residuals) / SS(baseline residuals).
nfolds <- 5
fold_stat_names <- c('rmse.item.mean', 'rmse.irt', 'adj.r.squared')
# Variables the JAGS models read from the observation rows.
model.vars <- c('t', 'qn', 'support', 'neutral', 'samplesize')

for (f in 1) {
  print(paste0('Attempt N. ', f))

  # as.character() makes createFolds() stratify by question label on every R
  # version; on R < 4.1, c(<factor>) returned integer codes, which
  # createFolds() bins by quantile instead of treating them as classes.
  set.seed(f)
  folds <- createFolds(as.character(cv.data$question), k = nfolds)
  fold_rows <- vector('list', nfolds)

  for (k in seq_len(nfolds)) {
    print(paste0('Fold N. ', k))

    # Create train and test data
    train <- cv.data[-folds[[k]], ]
    test <- cv.data[folds[[k]], ]
    print(paste0(nrow(test), ' observations ', length(unique(test$question)),
                 ' questions and ', length(unique(test$t)),
                 ' time points in test data. ', nrow(train),
                 ' observations in train data.'))

    # Number the questions 1..Q in sorted order, separately within each set.
    train <- train[order(train$question), ]
    train$qn <- as.numeric(factor(train$question))
    test <- test[order(test$question), ]
    test$qn <- as.numeric(factor(test$question))
    rownames(test) <- seq_len(nrow(test))

    # Test parameters are taken from the training fit, so every test question
    # must also occur in training (otherwise question numbers would misalign).
    stopifnot(all(test$question %in% train$question))

    # Estimate IRT model on the training fold
    b <- list(len = nrow(train), nquest = length(unique(train$question)),
              begin = min(train$t), end = max(train$t))
    train.begin <- b$begin
    train.end <- b$end
    n.train.q <- b$nquest
    # Pass only the variables the model reads (no character columns).
    irt.train <- c(train[model.vars], b)

    start_time <- Sys.time()
    print(start_time)
    irtfit <- jags.parallel(data = irt.train, inits = NULL,
                            parameters.to.save = c('mu', 'sigma', 'diff', 'disc', 'beta'),
                            n.burnin = 10000, n.iter = 100000, n.thin = 1, n.chains = 3,
                            model.file = 'main_model.temp')
    results.irt <- irtfit$BUGSoutput$summary
    fit.mcmc <- as.mcmc(irtfit)
    print(gelman.diag(fit.mcmc, autoburnin = FALSE)$mpsrf)
    priors.irt <- priorslist(results.irt, list('disc', 'mu', 'diff', 'sigma', 'beta'))

    # Keep only the question parameters of questions present in the test fold,
    # in train-question order (which matches the test numbering). diff is laid
    # out column-major: diff[1,1..Q] first, then diff[1..Q,2].
    keep.qn <- sort(unique(train$qn[train$question %in% test$question]))
    keep.diff <- c(keep.qn, keep.qn + n.train.q)
    priors.irt$m.disc <- priors.irt$m.disc[keep.qn]
    priors.irt$sd.disc <- priors.irt$sd.disc[keep.qn]
    priors.irt$m.diff <- priors.irt$m.diff[keep.diff]
    priors.irt$sd.diff <- priors.irt$sd.diff[keep.diff]

    # Relabel the retained entries with test question numbers.
    n.test.q <- length(unique(test$question))
    q.seq <- seq_len(n.test.q)
    disc.names <- paste0('disc[', q.seq, ']')
    diff.names <- c(paste0('diff[', q.seq, ',1]'), paste0('diff[', q.seq, ',2]'))
    names(priors.irt$m.disc) <- disc.names
    names(priors.irt$sd.disc) <- disc.names
    names(priors.irt$m.diff) <- diff.names
    names(priors.irt$sd.diff) <- diff.names

    b <- list(len = nrow(test), nquest = n.test.q,
              begin = min(test$t), end = max(test$t))

    # Element j of m.mu/m.sigma belongs to time train.begin + j - 1, while
    # model_fit and the predictions below index them from the TEST begin.
    # Cut them to the test window so both agree even when the earliest
    # training time points are absent from this test fold.
    if (b$begin < train.begin || b$end > train.end) {
      stop('Fold ', k, ': test time points ', b$begin, '-', b$end,
           ' fall outside the training period ', train.begin, '-', train.end,
           call. = FALSE)
    }
    time.window <- (b$begin:b$end) - (train.begin - 1)
    for (par.name in c('m.mu', 'sd.mu', 'm.sigma', 'sd.sigma')) {
      priors.irt[[par.name]] <- priors.irt[[par.name]][time.window]
    }

    irt.test <- c(test[model.vars], b,
                  priors.irt[c('m.beta', 'm.disc', 'm.diff', 'm.sigma', 'm.mu')])

    # Replicate IRT model on the test fold with the training estimates fixed.
    # NOTE(review): in model_fit, mu, sigma and diff.support are deterministic
    # given the data, so this run only re-derives diff.support.
    rep.irtfit <- jags.parallel(data = irt.test, inits = NULL,
                                parameters.to.save = c('mu', 'sigma', 'diff.support'),
                                n.burnin = 10000, n.iter = 100000, n.thin = 1, n.chains = 3,
                                model.file = 'model_fit.temp', DIC = FALSE)
    duration <- Sys.time() - start_time
    print(duration)

    test.results.irt <- rep.irtfit$BUGSoutput$summary
    diff.support <- test.results.irt[grepl('diff.support', dimnames(test.results.irt)[[1]]), 'mean']

    # Predicted support (%) = 100 * Phi((mu_t - diff.support_q) /
    # sqrt(disc_q^2 + sigma_t^2)), with mu, disc and sigma at their training
    # posterior means.
    time.pos <- test$t - (b$begin - 1)
    rep.mu <- unname(priors.irt$m.mu[time.pos] - diff.support[test$qn])
    rep.sigma <- unname(sqrt(priors.irt$m.disc[test$qn]^2 + priors.irt$m.sigma[time.pos]^2))
    predictions.irt <- data.frame(real.opinion = test$support,
                                  rn = seq_len(nrow(test)),
                                  pred.opinion = pnorm(rep.mu / rep.sigma) * 100,
                                  t = test$t, qn = test$qn, question = test$question)

    # Baseline: mean training support of each test question.
    q.labels <- as.character(test$question[match(q.seq, test$qn)])
    mean.quest <- vapply(q.labels, function(q) {
      mean(train$support[which(as.character(train$question) == q)])
    }, numeric(1), USE.NAMES = FALSE)
    predictions.irt$mean.quest <- mean.quest[predictions.irt$qn]

    predictions.irt$ss.mean.qn <- (predictions.irt$real.opinion - predictions.irt$mean.quest)^2
    predictions.irt$sqresid <- (predictions.irt$real.opinion - predictions.irt$pred.opinion)^2
    fold_rows[[k]] <- c(sqrt(sum(predictions.irt$ss.mean.qn) / nrow(predictions.irt)),
                        sqrt(mean(predictions.irt$sqresid)),
                        1 - sum(predictions.irt$sqresid) / sum(predictions.irt$ss.mean.qn))
    cv_folds <- do.call(rbind, fold_rows[seq_len(k)])
    print(mean(cv_folds[, 3])) # running mean of the per-fold R-squared
  }
  colnames(cv_folds) <- fold_stat_names
  cv_results <- if (f == 1) cv_folds else rbind(cv_results, cv_folds)
}

# Write the per-fold cross-validation statistics and their averages.
# Output is captured first and then written, so an error while printing
# cannot leave a sink() diverting all later console output.
cv_report <- capture.output(
  print(cv_results),
  print('All models estimated with 100,000 iterations'),
  print(paste0('ITEM MEANS: ', mean(cv_results[, 1]))),
  print(paste0('ROOT MSE: ', mean(cv_results[, 2]))),
  print(paste0('ADJ. R SQUARED: ', mean(cv_results[, 3]))) # value noted in the original script: 0.16012
)
writeLines(cv_report, paste0(outdir, 'cross_validation_irt_model.txt'))